--- external/eg3d/training/superresolution.py	2023-03-14 23:39:22.320046811 +0000
+++ external_reference/eg3d/training/superresolution.py	2023-03-14 23:28:05.379774778 +0000
@@ -12,14 +12,14 @@
 "Efficient Geometry-aware 3D Generative Adversarial Networks"."""
 
 import torch
-from training.networks_stylegan2 import Conv2dLayer, SynthesisLayer, ToRGBLayer
+from external.stylegan.training.networks_stylegan2_terrain import Conv2dLayer, SynthesisLayer, ToRGBLayer
 from torch_utils.ops import upfirdn2d
 from torch_utils import persistence
 from torch_utils import misc
 
-from training.networks_stylegan2 import SynthesisBlock
+from external.stylegan.training.networks_stylegan2_terrain import SynthesisBlock
 import numpy as np
-from training.networks_stylegan3 import SynthesisLayer as AFSynthesisLayer
+from external.stylegan.training.networks_stylegan3_sky import SynthesisLayer as AFSynthesisLayer
 
 
 #----------------------------------------------------------------------------
@@ -57,25 +57,29 @@
 
 #----------------------------------------------------------------------------
 
-# for 256x256 generation
+# for 256x256 generation -- modified so the to_rgb branch takes RGBD input and
+# produces RGBD output, and so the style code input can be ignored (ignore_w)
 @persistence.persistent_class
 class SuperresolutionHybrid4X(torch.nn.Module):
     def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias,
                 num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,# IGNORE
-                **block_kwargs):
+                ignore_w=True, **block_kwargs):  # ignore_w: when True, forward() replaces ws with zeros
         super().__init__()
         assert img_resolution == 256
         use_fp16 = sr_num_fp16_res > 0
         self.sr_antialias = sr_antialias
         self.input_resolution = 128
         self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=512, resolution=128,
-                img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+                img_channels=4, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)  # 4 image channels: RGBD instead of RGB
         self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=256,
-                img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+                img_channels=4, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)  # 4 image channels: RGBD instead of RGB
+        self.ignore_w = ignore_w  # stored for forward(), which zeroes the style codes when this is set
         self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
 
     def forward(self, rgb, x, ws, **block_kwargs):
         ws = ws[:, -1:, :].repeat(1, 3, 1)
+        if self.ignore_w:
+            ws = torch.zeros_like(ws) # input to affine layer with bias=1
 
         if x.shape[-1] < self.input_resolution:
             x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
@@ -289,4 +293,4 @@
         x, rgb = self.block1(x, rgb, ws, **block_kwargs)
         return rgb
 
-#----------------------------------------------------------------------------
\ No newline at end of file
+#----------------------------------------------------------------------------
